#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

from .layer_config import *
from . import conv_blocks
from . import common_blocks


###########################################################
def ConvGWSepNormAct2d(in_planes, out_planes, stride=1, kernel_size=None, groups=1, dilation=1, bias=False, \
                   first_1x1=False, normalization=(DefaultNorm2d,DefaultNorm2d), activation=(DefaultAct2d,DefaultAct2d),
                   shuffle=True, **kwargs):
    """Build a group-wise separable conv block: grouped KxK conv -> (shuffle) -> 1x1 conv.

    The first convolution keeps the channel count (in_planes -> in_planes) and
    uses ``in_planes // 4`` groups, i.e. 4 input channels per group. The second
    convolution is a 1x1 conv (in_planes -> out_planes) using ``groups`` groups.

    Args:
        in_planes: number of input channels; must be at least 4.
        out_planes: number of output channels of the final 1x1 conv.
        stride: stride applied by the first (KxK) convolution.
        kernel_size: kernel size of the first convolution.
        groups: number of groups of the final 1x1 convolution.
        dilation: dilation of the first convolution.
        bias: bias setting forwarded to both convolutions.
        first_1x1: accepted for signature compatibility; not used here.
        normalization: pair of normalization settings, one per convolution.
        activation: pair of activation settings, one per convolution.
        shuffle: if True and ``groups != 1``, insert a channel shuffle between
            the two convolutions.
        **kwargs: accepted and ignored.

    Returns:
        torch.nn.Sequential containing the layers in execution order.

    Raises:
        ValueError: if ``in_planes`` is smaller than 4 (group count would be 0).
    """
    num_ch_in_dws = 4
    groups_dws = in_planes//num_ch_in_dws
    if groups_dws < 1:
        # a group count of 0 would otherwise fail deep inside the conv constructor
        raise ValueError('ConvGWSepNormAct2d: in_planes must be >= {}, got {}'.format(num_ch_in_dws, in_planes))
    #

    # stride is applied on the spatial (KxK) convolution; previously it was silently dropped
    layers = [conv_blocks.ConvNormAct2d(in_planes, in_planes, stride=stride, groups=groups_dws, kernel_size=kernel_size,
                bias=bias, dilation=dilation, normalization=normalization[0], activation=activation[0])]

    ###########################################################################
    # shuffle between the first KxK conv and the 1x1 conv, with groups = Ni/4
    # Example: Ni = 64, No = 64, G = 4
    # the KxK conv with G = 16 works on [0,1,2,3], [4,5,6,7], ..., [60,61,62,63]
    # after a shuffle with G = 16 the channel order becomes
    # [0,4,8,12,...,60], [1,5,9,13,...,61], [2,6,10,14,...,62], [3,7,11,15,...,63]
    # so the following 1x1 conv with G = 4 mixes channels from different groups.
    ###########################################################################
    if shuffle and (groups != 1):
        layers += [common_blocks.ShuffleBlock(groups=groups_dws)]
    #

    layers += [conv_blocks.ConvNormAct2d(in_planes, out_planes, groups=groups, kernel_size=1, bias=bias,
                normalization=normalization[1], activation=activation[1])]

    layers = torch.nn.Sequential(*layers)
    return layers


######################################################
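# group-wise variant of the ASPP "lite" block: each dilated branch is built with
# ConvGWSepNormAct2d (grouped 3x3 conv followed by a 1x1 conv) rather than a
# depthwise (ConvDWNormAct2d / ConvDWSepNormAct2d) convolution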
class GWASPPLiteBlock(torch.nn.Module):
    """ASPP block whose dilated branches are group-wise separable convolutions.

    The input is processed by a 1x1 conv branch and three dilated 3x3 branches
    (each built by ConvGWSepNormAct2d), plus an optional global-average-pool
    branch. The branch outputs are concatenated, passed through dropout and
    projected to ``out_chs`` channels by a 1x1 conv.

    Args:
        in_chs: number of input channels.
        aspp_chs: number of output channels of every branch.
        out_chs: number of output channels of the block.
        dilation: sequence with the dilation of each of the three 3x3 branches.
        groups: groups used by the 1x1 conv inside each dilated branch.
        avg_pool: if True, add the global-average-pool branch.
        activation: activation constructor/setting used throughout the block.
        linear_dw: if True, the 3x3 conv of each dilated branch gets neither
            normalization nor activation.
    """
    def __init__(self, in_chs, aspp_chs, out_chs, dilation=(6, 12, 18), groups=1, avg_pool=False, activation=DefaultAct2d, linear_dw=False):
        super().__init__()

        self.aspp_chs = aspp_chs
        self.avg_pool = avg_pool
        # 4 branches (1x1 + three dilated), plus one more when the pooled branch is enabled
        self.last_chns = aspp_chs * (4 + (1 if self.avg_pool else 0))

        if self.avg_pool:
            # activation -> global average pool to 1x1 -> 1x1 conv (in_chs -> aspp_chs) -> activation
            self.gave_pool = torch.nn.Sequential(activation(inplace=False), torch.nn.AdaptiveAvgPool2d((1, 1)),
                                           torch.nn.Conv2d(in_chs, aspp_chs, kernel_size=1), activation(inplace=True))
        #

        self.conv1x1 = conv_blocks.ConvNormAct2d(in_chs, aspp_chs, kernel_size=1, activation=activation)
        # per-conv (3x3, 1x1) settings for the dilated branches; linear_dw disables both on the 3x3 conv
        normalizations_dw = ((not linear_dw), True)
        activations_dw = (False if linear_dw else activation, activation)
        self.aspp_bra1 = ConvGWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[0], normalization=normalizations_dw, activation=activations_dw, groups=groups)
        self.aspp_bra2 = ConvGWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[1], normalization=normalizations_dw, activation=activations_dw, groups=groups)
        self.aspp_bra3 = ConvGWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[2], normalization=normalizations_dw, activation=activations_dw, groups=groups)

        self.dropout = torch.nn.Dropout2d(p=0.2, inplace=True)
        self.aspp_out = conv_blocks.ConvNormAct2d(self.last_chns, out_chs, kernel_size=1, groups=1, activation=activation)
        self.cat = common_blocks.CatBlock()

    def forward(self, x):
        """Run all branches on ``x``, concatenate them and project to ``out_chs`` channels."""
        x1 = self.conv1x1(x)
        b1 = self.aspp_bra1(x)
        b2 = self.aspp_bra2(x)
        b3 = self.aspp_bra3(x)

        if self.avg_pool:
            # gave_pool consumes the block input directly (its 1x1 conv expects in_chs channels);
            # the former self.aspp_in was never defined and raised AttributeError here.
            pooled = self.gave_pool(x)
            # resize the 1x1 pooled map back to the spatial size of the input;
            # align_corners=False is the documented default, made explicit here.
            xavg = torch.nn.functional.interpolate(pooled, size=x.shape[2:], mode='bilinear', align_corners=False)
            branches = [xavg, x1, b1, b2, b3]
        else:
            branches = [x1, b1, b2, b3]
        #

        cat = self.cat(branches)
        cat = self.dropout(cat)
        out = self.aspp_out(cat)
        return out
#


